Arborenv In-Need Regions ML Usecase

In [1]:
# Common
import os 
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from IPython.display import clear_output as cls

# Data 
from tqdm import tqdm
import tensorflow.data as tfd

# Data Visualization
import matplotlib.pyplot as plt

# Model Building
from tensorflow.keras import layers
from tensorflow.keras import callbacks
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from tensorflow.keras.optimizers.schedules import ExponentialDecay

# Model visualization
from tensorflow.keras.utils import plot_model

# Extra
from typing import List, Tuple, Union

Hyperparameters and constants¶

In [2]:
# Image and Mask Dimensions
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 160
N_IMAGE_CHANNELS = 3
N_MASK_CHANNELS = 1

# Image and Mask Size
IMAGE_SIZE = (IMAGE_WIDTH, IMAGE_HEIGHT, N_IMAGE_CHANNELS)
MASK_SIZE = (IMAGE_WIDTH, IMAGE_HEIGHT, N_MASK_CHANNELS)

# Batch Size and Learning Rate
BATCH_SIZE = 32
BASE_LR = 1e-2

# Model Name
MODEL_NAME = 'UNetForestSegmentation'

# Model Training
EPOCHS = 100

# Data Paths
ROOT_IMAGE_DIR = 'Forest_Segmented/images'
ROOT_MASK_DIR = 'Forest_Segmented/masks'
METADATA_CSV_PATH = 'Forest_Segmented/meta_data.csv'

# Model Architecture
FILTERS = 32
In [3]:
# Random Seed
SEED = 42

np.random.seed(SEED)
tf.random.set_seed(SEED)

Utility Functions¶

In [4]:
def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[tf.Tensor, tf.Tensor]:
    
    '''
    This function takes the file paths of an image and its corresponding mask as input. It first reads the images, then decodes them into tensors, 
    and resizes them to a standard size. After that, the image and mask tensors are normalized by clipping the pixel values between 0 and 1. 
    Finally, the function converts the image and mask tensors to the float32 data type and returns them as a tuple.
    
    Arguments : 
        image_path : The path to the image to be loaded. 
        mask_path  : The path to the mask to be loaded.
    
    Returns :
        image : This is the loaded and the processed image. 
        mask  : This is the loaded and the processed mask.
    
    '''
    
    # Read the images
    image = tf.io.read_file(filename = image_path)
    mask  = tf.io.read_file(filename = mask_path)
    
    # Decode the images
    image = tf.image.decode_jpeg(contents = image, channels = N_IMAGE_CHANNELS)
    mask  = tf.image.decode_jpeg(contents = mask,  channels = N_MASK_CHANNELS)
    
    # Convert the image to a Tensor
    image = tf.image.convert_image_dtype(image = image, dtype = tf.float32)
    mask  = tf.image.convert_image_dtype(image = mask, dtype = tf.float32)
    
    # Resize the image to the desired dimensions
    image = tf.image.resize(images = image, size = (IMAGE_WIDTH, IMAGE_HEIGHT))
    mask  = tf.image.resize(images = mask, size = (IMAGE_WIDTH, IMAGE_HEIGHT))
    
    # Normalize the image
    image = tf.clip_by_value(image, clip_value_min = 0.0, clip_value_max = 1.0)
    mask  = tf.clip_by_value(mask, clip_value_min = 0.0, clip_value_max = 1.0)
    
    # Final conversion
    image = tf.cast(image, dtype = tf.float32)
    mask  = tf.cast(mask,  dtype = tf.float32)
    
    return image, mask
In [5]:
# Load CSV File
metadata = pd.read_csv(METADATA_CSV_PATH)

# Quick look
metadata.head()
Out[5]:
image mask
0 10452_sat_08.jpg 10452_mask_08.jpg
1 10452_sat_18.jpg 10452_mask_18.jpg
2 111335_sat_00.jpg 111335_mask_00.jpg
3 111335_sat_01.jpg 111335_mask_01.jpg
4 111335_sat_02.jpg 111335_mask_02.jpg
In [6]:
# Define indices
start_index = 0
end_index = 1500

# Slice metadata to get the first 2000 entries
metadata_subset = metadata.iloc[start_index:end_index].copy()

# Add root path to image file names and ensure forward slashes
metadata_subset['image'] = [os.path.normpath(os.path.join(ROOT_IMAGE_DIR, filename)).replace('\\', '/') for filename in metadata_subset['image']]

# Add root path to mask file names and ensure forward slashes
metadata_subset['mask'] = [os.path.normpath(os.path.join(ROOT_MASK_DIR, filename)).replace('\\', '/') for filename in metadata_subset['mask']]
In [7]:
# Quick Check
metadata_subset.head()
Out[7]:
image mask
0 Forest_Segmented/images/10452_sat_08.jpg Forest_Segmented/masks/10452_mask_08.jpg
1 Forest_Segmented/images/10452_sat_18.jpg Forest_Segmented/masks/10452_mask_18.jpg
2 Forest_Segmented/images/111335_sat_00.jpg Forest_Segmented/masks/111335_mask_00.jpg
3 Forest_Segmented/images/111335_sat_01.jpg Forest_Segmented/masks/111335_mask_01.jpg
4 Forest_Segmented/images/111335_sat_02.jpg Forest_Segmented/masks/111335_mask_02.jpg
In [8]:
def load_dataset(
    image_paths: list, mask_paths: list, split_ratio: float=0.7, 
    batch_size: int=BATCH_SIZE, shuffle: bool=True, 
    buffer_size: int=1000, n_repeat: int=1
) -> Union[Tuple[tfd.Dataset, tfd.Dataset], tfd.Dataset]:
    '''
    This function loads the image and mask data from the provided file paths and creates a TensorFlow dataset. The function
    first creates space to store the image and mask data in numpy arrays. It then iterates over each image and mask pair, 
    loading them using the load_image_and_mask function and storing them in the numpy arrays.
    
    The function then creates a TensorFlow dataset using the numpy arrays. If shuffle is True, it shuffles the dataset
    with a buffer size of buffer_size. If split_ratio is not None, it splits the dataset into two parts with sizes determined
    by the split_ratio, and converts them into batches of size batch_size with drop_remainder=True. The two resulting datasets
    are returned as a tuple.

    If split_ratio is None, the entire dataset is converted into batches of size batch_size with drop_remainder=True, 
    and the resulting dataset is returned.
    
    Args:
        image_paths: A list of strings, containing the file paths of the input images.
        
        mask_paths: A list of strings, containing the file paths of the corresponding mask images.
        
        split_ratio: A float value between 0 and 1, representing the ratio of data to be used for validation. 
                    If split_ratio is set to None, then no data will be split for validation.
                    
        batch_size: An integer, representing the batch size for the input data.
        
        shuffle: A boolean value indicating whether the data should be shuffled or not.
        
        buffer_size: An integer, representing the buffer size for shuffling the data.
        
        n_repeat: An integer, representing the total number of repetations of the data.
    
    Returns:
        If split_ratio is not None, then the function returns a tuple of two Tensorflow datasets. 
        The first dataset contains the training data and the second dataset contains the validation data.
        
        If split_ratio is None, then the function returns a single Tensorflow dataset containing the 
        input data batched and pre-fetched for training.
    
    '''
    
    # Create space for storing the data.
    images = np.empty(shape=(len(image_paths), *IMAGE_SIZE), dtype=np.float32)
    masks  = np.empty(shape=(len(mask_paths), *MASK_SIZE),  dtype=np.float32)
    
    # Iterate over the data.
    index = 0
    for image_path, mask_path in tqdm(zip(image_paths, mask_paths), desc='Loading'):
        
        # Load the image and the mask.
        image, mask = load_image_and_mask(image_path = image_path, mask_path = mask_path)
        
        # Store the image and the mask.
        images[index] = image
        masks[index]  = mask
        
        # Increment the index.
        index += 1
    
    # Create a Tensorflow data.
    data_set = tfd.Dataset.from_tensor_slices((images, masks)).repeat(n_repeat)
    
    # Shuffle the data set.
    if shuffle:
        data_set = data_set.shuffle(buffer_size)
    
    # Split the data 
    if split_ratio is not None:
        
        split_ratio_val_test = (1 - split_ratio)/2

        # Calculate new data sizes after splitting.
        data_1_len = int(split_ratio * len(images))
        print(data_1_len)
        data_2_len = int(split_ratio_val_test * len(images))
        print(data_2_len)
        
        # Divide the data into 2 parts.
        data_1 = data_set.take(data_1_len)
        data_2 = data_set.skip(data_1_len).take(data_2_len)
        data_3 = data_set.skip(data_1_len + data_2_len).take(data_2_len)
        
        # Convert data into batches.
        data_1 = data_1.batch(batch_size, drop_remainder=True).prefetch(tfd.AUTOTUNE)
        data_2 = data_2.batch(batch_size, drop_remainder=True).prefetch(tfd.AUTOTUNE)
        data_3 = data_3.batch(batch_size, drop_remainder=True).prefetch(tfd.AUTOTUNE)
        
        # Return the data 
        return data_1, data_2, data_3
    
    else:
        
        # Convert data into batches
        data_set = data_set.batch(batch_size, drop_remainder=True).prefetch(tfd.AUTOTUNE)
        
        # Return the data
        return data_set
In [9]:
# Training and Testing Data
train_ds, test_ds, valid_ds = load_dataset(
    image_paths = metadata_subset['image'],
    mask_paths = metadata_subset['mask'],
    split_ratio = 0.7,
    shuffle = True,
    n_repeat=3,
)
Loading: 1500it [00:03, 498.11it/s]
1050
225
In [10]:
print("*"*100)
print(f"{' '*30}Training Data Size : {train_ds.cardinality().numpy() * BATCH_SIZE}")
print(f"{' '*30}Testing Data Size  : {test_ds.cardinality().numpy() * BATCH_SIZE}")
print(f"{' '*30}Validation Data Size  : {valid_ds.cardinality().numpy() * BATCH_SIZE}")
print("*"*100)
****************************************************************************************************
                              Training Data Size : 1024
                              Testing Data Size  : 224
                              Validation Data Size  : 224
****************************************************************************************************
In [11]:
# # Training Data size
# full_train_size = full_train_ds.cardinality().numpy()

# # Split Ratio
# train_val_split = 0.1
# valid_size = int(full_train_size * train_val_split)
# train_size = full_train_size - valid_size

# # Split Data 
# train_ds = full_train_ds.take(train_size)
# valid_ds = full_train_ds.skip(train_size).take(valid_size)
In [12]:
train_ds.cardinality().numpy(),valid_ds.cardinality().numpy(),test_ds.cardinality().numpy()
Out[12]:
(32, 7, 7)
In [13]:
# print("*"*100)
# print(f"{' '*30}Training Data Size   : {train_ds.cardinality().numpy() * BATCH_SIZE}")
# print(f"{' '*30}Validation Data Size : {valid_ds.cardinality().numpy() * BATCH_SIZE}")
# print(f"{' '*30}Testing Data Size    : {test_ds.cardinality().numpy() * BATCH_SIZE}")
# print("*"*100)

Data Visualization¶

In [14]:
def show_images_and_masks(data : tfd.Dataset, n_images: int=10, FIGSIZE: tuple=(25, 5), model: tf.keras.Model=None):
    # Configuration
    if model is None:
        n_cols = 3
    else:
        n_cols = 5
    
    # Collect the data
    images, masks = next(iter(data))
    
    # Iterate over the data
    for n in range(n_images):
        
        # Plotting configuration
        plt.figure(figsize=FIGSIZE)
        
        # Plot the image
        plt.subplot(1, n_cols, 1)
        plt.title("Original Image")
        plt.imshow(images[n])
        plt.axis('off')
        
        # Plot the Mask
        plt.subplot(1, n_cols, 2)
        plt.title("Original Mask")
        plt.imshow(masks[n], cmap='gray')
        plt.axis('off')
        
        # Plot image and mask overlay
        plt.subplot(1, n_cols, 3)
        plt.title('Image and Mask overlay')
        plt.imshow(masks[n], alpha=0.8, cmap='binary_r')
        plt.imshow(images[n], alpha=0.5)
        plt.axis('off')
        
        # Model predictions
        if model is not None:
            pred_mask = model.predict(tf.expand_dims(images[n], axis=0))[0]
            pred_mask = pred_mask>=0.5 # threshold = 0.5
            plt.subplot(1, n_cols, 4)
            plt.title('Predicted Mask')
            plt.imshow(pred_mask, cmap='gray')
            plt.axis('off')
            
            plt.subplot(1, n_cols, 5)
            plt.title('Predicted Mask Overlay')
            plt.imshow(pred_mask, alpha=0.8, cmap='binary_r')
            plt.imshow(images[n], alpha=0.5)
            plt.axis('off')
    
        # Show final plot
        plt.show()

show_images_and_masks(data=train_ds)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

U-Net¶

  • Image segmentation model
  • Encoder and Decoder with skip connection
  • Total: Encoder -16 layers, Decoder - 16 layers
  • Parametes:
No description has been provided for this image

Unet - Encoder Block¶

In [15]:
class EncoderBlock(layers.Layer):
    
    def __init__(self, filters: int, max_pool: bool=True, rate=0.2, **kwargs) -> None:
        super().__init__(**kwargs)
        
        # Params
        self.rate = rate
        self.filters = filters
        self.max_pool = max_pool
        
        # Layers : Initialize the model layers that will be later called
        self.max_pooling = layers.MaxPool2D(pool_size=(2,2), strides=(2,2))
        self.conv1 = layers.Conv2D(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding='same',
            activation='relu',
            kernel_initializer='he_normal'
        )
        self.conv2 = layers.Conv2D(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding='same',
            activation='relu',
            kernel_initializer='he_normal'
        )
        self.drop = layers.Dropout(rate)
        self.bn = layers.BatchNormalization()
        
    def call(self, X, **kwargs):
        
        X = self.bn(X) # BatchNomlarization
        X = self.conv1(X)
        X = self.drop(X)
        X = self.conv2(X)
        
        # Apply Max Pooling if required
        if self.max_pool:
            y = self.max_pooling(X)
            return y, X
        else:
            return X
    
    def get_config(self):
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'max_pool': self.max_pool,
            'rate': self.rate
        })

    def __repr__(self):
        return f"{self.__class__.name}(F={self.filters}, Pooling={self.max_pool})"

UNet - Decoder Block¶

In [16]:
class DecoderBlock(layers.Layer):
    
    def __init__(self, filters: int, rate: float = 0.2, **kwargs):
        super().__init__(**kwargs)
        
        self.filters = filters
        self.rate = rate
        
        # Initialize the model layers
        self.convT = layers.Conv2DTranspose(
            filters = filters,
            kernel_size = 3,
            strides = 2,
            padding = 'same',
            activation = 'relu',
            kernel_initializer = 'he_normal'
        )
        self.bn = layers.BatchNormalization()
        self.net = EncoderBlock(filters = filters, rate = rate, max_pool = False)
        
    def call(self, inputs, **kwargs):
        
        # Get both the inputs
        X, skip_X = inputs
        
        # Up-sample the skip connection
        X = self.bn(X)
        X = self.convT(X)
        
        # Concatenate both inputs
        X = layers.Concatenate(axis=-1)([X, skip_X])
        X = self.net(X)
        
        return X

    def get_config(self):
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'rate': self.rate,
        })
        return config

    def __repr__(self):
        return f"{self.__class__.__name__}(F={self.filters}, rate={self.rate})"

UNet - Encoder Decoder Net¶

In [17]:
# Input Layer
input_layer = layers.Input(shape=(IMAGE_SIZE), name="InputLayer")

# The encoder network
pool1, encoder1 = EncoderBlock(FILTERS,   max_pool=True, rate=0.1, name="EncoderLayer1")(input_layer)
pool2, encoder2 = EncoderBlock(FILTERS*2, max_pool=True, rate=0.1, name="EncoderLayer2")(pool1)
pool3, encoder3 = EncoderBlock(FILTERS*4, max_pool=True, rate=0.2, name="EncoderLayer3")(pool2)
pool4, encoder4 = EncoderBlock(FILTERS*8, max_pool=True, rate=0.2, name="EncoderLayer4")(pool3)

# The encoder encoding
encoding = EncoderBlock(FILTERS*16, max_pool=False, rate=0.3, name="EncodingSpace")(pool4)

# The decoder network
decoder4 = DecoderBlock(FILTERS*8, rate=0.2, name="DecoderLayer1")([encoding, encoder4])
decoder3 = DecoderBlock(FILTERS*4, rate=0.2, name="DecoderLayer2")([decoder4, encoder3])
decoder2 = DecoderBlock(FILTERS*2, rate=0.1, name="DecoderLayer3")([decoder3, encoder2])
decoder1 = DecoderBlock(FILTERS,  rate=0.1, name="DecoderLayer4")([decoder2, encoder1])
        
# Final output layer.
final_conv = layers.Conv2D(
    filters = 1, 
    kernel_size = 1, 
    strides=1, 
    padding='same', 
    activation='sigmoid', 
    name="OutputMap"
)(decoder1)

# Unet Model
unet_model = keras.Model(
    inputs = input_layer,
    outputs = final_conv,
    name = "UNetModel"
)
WARNING:tensorflow:From C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\backend\tensorflow\core.py:192: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\layers\layer.py:372: UserWarning: `build()` was called on layer 'EncodingSpace', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\layers\layer.py:372: UserWarning: `build()` was called on layer 'encoder_block', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\layers\layer.py:372: UserWarning: `build()` was called on layer 'encoder_block_1', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\layers\layer.py:372: UserWarning: `build()` was called on layer 'encoder_block_2', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
C:\Users\Salma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\LocalCache\local-packages\Python312\site-packages\keras\src\layers\layer.py:372: UserWarning: `build()` was called on layer 'encoder_block_3', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
In [18]:
# Model Summary
unet_model.summary()
Model: "UNetModel"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ InputLayer          │ (None, 160, 160,  │          0 │ -                 │
│ (InputLayer)        │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ EncoderLayer1       │ [(None, 80, 80,   │     10,156 │ InputLayer[0][0]  │
│ (EncoderBlock)      │ 32), (None, 160,  │            │                   │
│                     │ 160, 32)]         │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ EncoderLayer2       │ [(None, 40, 40,   │     55,552 │ EncoderLayer1[0]… │
│ (EncoderBlock)      │ 64), (None, 80,   │            │                   │
│                     │ 80, 64)]          │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ EncoderLayer3       │ [(None, 20, 20,   │    221,696 │ EncoderLayer2[0]… │
│ (EncoderBlock)      │ 128), (None, 40,  │            │                   │
│                     │ 40, 128)]         │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ EncoderLayer4       │ [(None, 10, 10,   │    885,760 │ EncoderLayer3[0]… │
│ (EncoderBlock)      │ 256), (None, 20,  │            │                   │
│                     │ 20, 256)]         │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ EncodingSpace       │ (None, 10, 10,    │  3,540,992 │ EncoderLayer4[0]… │
│ (EncoderBlock)      │ 512)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ DecoderLayer1       │ (None, 20, 20,    │  2,953,984 │ EncodingSpace[0]… │
│ (DecoderBlock)      │ 256)              │            │ EncoderLayer4[0]… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ DecoderLayer2       │ (None, 40, 40,    │    739,712 │ DecoderLayer1[0]… │
│ (DecoderBlock)      │ 128)              │            │ EncoderLayer3[0]… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ DecoderLayer3       │ (None, 80, 80,    │    185,536 │ DecoderLayer2[0]… │
│ (DecoderBlock)      │ 64)               │            │ EncoderLayer2[0]… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ DecoderLayer4       │ (None, 160, 160,  │     46,688 │ DecoderLayer3[0]… │
│ (DecoderBlock)      │ 32)               │            │ EncoderLayer1[0]… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ OutputMap (Conv2D)  │ (None, 160, 160,  │         33 │ DecoderLayer4[0]… │
│                     │ 1)                │            │                   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 8,640,109 (32.96 MB)
 Trainable params: 8,635,303 (32.94 MB)
 Non-trainable params: 4,806 (18.77 KB)
In [19]:
# Inspect trainable and non-trainable parameters
for layer in unet_model.layers:
    print(f"Layer: {layer.name}")
    print(f"  Trainable: {layer.trainable}")
    print(f"  Non-trainable weights: {len(layer.non_trainable_weights)}")
    print(f"  Trainable weights: {len(layer.trainable_weights)}")
Layer: InputLayer
  Trainable: True
  Non-trainable weights: 0
  Trainable weights: 0
Layer: EncoderLayer1
  Trainable: True
  Non-trainable weights: 2
  Trainable weights: 6
Layer: EncoderLayer2
  Trainable: True
  Non-trainable weights: 2
  Trainable weights: 6
Layer: EncoderLayer3
  Trainable: True
  Non-trainable weights: 2
  Trainable weights: 6
Layer: EncoderLayer4
  Trainable: True
  Non-trainable weights: 2
  Trainable weights: 6
Layer: EncodingSpace
  Trainable: True
  Non-trainable weights: 2
  Trainable weights: 6
Layer: DecoderLayer1
  Trainable: True
  Non-trainable weights: 4
  Trainable weights: 10
Layer: DecoderLayer2
  Trainable: True
  Non-trainable weights: 4
  Trainable weights: 10
Layer: DecoderLayer3
  Trainable: True
  Non-trainable weights: 4
  Trainable weights: 10
Layer: DecoderLayer4
  Trainable: True
  Non-trainable weights: 4
  Trainable weights: 10
Layer: OutputMap
  Trainable: True
  Non-trainable weights: 0
  Trainable weights: 2

UNet - Model Training¶

In [20]:
class ShowProgress(callbacks.Callback):
    """A callback that displays the original image, the original mask, 
    the predicted mask, and the Grad-CAM visualization for a sample image 
    after each epoch of training.
    
    Args:
        data (tf.data.Dataset): A dataset of image-mask pairs.
        layer_name (str): The name of the layer to use for Grad-CAM.
        cmap (str, optional): The colormap to use for displaying the masks. 
            Defaults to 'gray'.
        output_dir (str, optional): The directory to save the output images. 
            If None, the images will not be saved. Defaults to None.
        num_images (int, optional): The number of images to display. 
            Defaults to 1.
        file_format (str, optional): The format to save the output images in. 
            Defaults to 'png'.
    """
    def __init__(self, data: tf.data.Dataset, layer_name: str, cmap: str = 'gray', 
                 output_dir: str = None, num_images: int = 1, file_format: str = 'png',
                 **kwargs):
        super().__init__(**kwargs)
        
        # Validate inputs
        if not isinstance(data, tf.data.Dataset):
            raise ValueError('The `data` parameter must be a tf.data.Dataset.')
        if not isinstance(layer_name, str):
            raise ValueError('The `layer_name` parameter must be a string.')
        if not isinstance(num_images, int) or num_images < 1:
            raise ValueError('The `num_images` parameter must be an integer greater than 0.')
        if file_format not in ['png', 'jpg', 'pdf']:
            raise ValueError('The `file_format` parameter must be "png", "jpg", or "pdf".')
        
        self.data = data
        self.layer_name = layer_name
        self.cmap = cmap
        self.output_dir = output_dir
        self.num_images = num_images
        self.file_format = file_format
    
    def on_epoch_end(self, epoch, logs=None):
        """Displays the original image, the original mask, the predicted mask, 
        and the Grad-CAM visualization for a sample image.
        """
        # Plotting configuration
        plt.figure(figsize=(25, 8 * self.num_images))
        
        for i in range(self.num_images):
            # Get Data 
            images, masks = next(iter(self.data))
            images = images.numpy()
            masks = masks.numpy()

            # Select image
            index = np.random.randint(len(images))
            image, mask = images[index], masks[index]

            # Make Prediction
            pred_mask = self.model.predict(np.expand_dims(image, axis=0))[0]

            # Show Image
            plt.subplot(1, 3, 1)
            plt.title("Original Image")
            plt.imshow(image)
            plt.axis('off')

            # Show Mask
            plt.subplot(1, 3, 2)
            plt.title("Original Mask")
            plt.imshow(mask, cmap=self.cmap)
            plt.axis('off')

            # Show Model Pred
            plt.subplot(1, 3, 3)
            plt.title("Predicted Mask")
            plt.imshow(pred_mask, cmap=self.cmap)
            plt.axis('off')

             # Save figure
            if self.output_dir is not None:
                path = os.path.join(os.curdir, self.output_dir)
                plt.savefig(f'Epoch({epoch+1})-Viz.{self.file_format}')

            # Show Final plot
            plt.show()
In [21]:
test_images, test_masks = next(iter(test_ds))

CALLBACKS = [
    callbacks.EarlyStopping(
        patience = 10, 
        restore_best_weights = True),
#     callbacks.ModelCheckpoint(
#         MODEL_NAME + '.h5', 
#         save_best_only = True),
    ShowProgress(
        data = valid_ds,
        layer_name = "DecoderLayer4"
    )
]
In [22]:
def dice_coeff(y_true: tf.Tensor, y_pred: tf.Tensor, smooth: float=1.0) -> tf.Tensor:
    
    """Compute the Dice coefficient between predicted and true masks.

    Args:
        y_true (tf.Tensor): True masks. Shape (batch_size, height, width, num_channels).
        y_pred (tf.Tensor): Predicted masks. Shape (batch_size, height, width, num_channels).
        smooth (float): Smoothing factor to avoid division by zero.

    Returns:
        tf.Tensor: Dice coefficient score.

    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3])
    dice = tf.reduce_mean((2.0 * intersection + smooth) / (union + smooth), axis=0)

    return tf.cast(dice, tf.float32)
In [23]:
# Pixel Accuracy 
pixel_acc = metrics.Accuracy(name="PixelAccuracy")

# Mean Intersection Over Union
mean_iou = metrics.MeanIoU(num_classes=2, name="MeanIoU")

# Exponential learning rate decay
'''
For example,
At Step 0: Learning Rate = 0.001
At Step 500: Learning Rate = 0.001 * 0.96 = 0.00096
At Step 1000: Learning Rate = 0.00096 * 0.96 = 0.0009216
And so on...
'''
initial_learning_rate = BASE_LR
decay_steps = 500  # learning rate will be updated everey 500 steps,
decay_rate = 0.96  # learning rate = learning rate * decay_reate, so lr will decrease 4% 

lr_schedule = ExponentialDecay(
    initial_learning_rate,
    decay_steps,
    decay_rate,
    staircase=True  # learnning rate will drop like step, if staircase=False, it will drop smoothly.
)

optimizer = optimizers.Adam(learning_rate=lr_schedule)

# Compile Model
unet_model.compile(
    loss = 'binary_crossentropy',
    optimizer = optimizer,
    metrics = [
        pixel_acc,
        mean_iou,
        dice_coeff
    ]
)
In [24]:
# Model Training
unet_model_history = unet_model.fit(
    train_ds,
    validation_data = valid_ds,
    epochs = EPOCHS,
    callbacks = CALLBACKS,
    batch_size = BATCH_SIZE,
)
Epoch 1/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 341ms/step- MeanIoU: 0.1953 - PixelAccuracy: 7.7410e-07 - dice_coeff: 0.5757 - loss: 0.73
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.1955 - PixelAccuracy: 7.5642e-07 - dice_coeff: 0.5764 - loss: 0.7291 - val_MeanIoU: 0.2000 - val_PixelAccuracy: 0.3758 - val_dice_coeff: 0.0024 - val_loss: 2211.1560
Epoch 2/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - MeanIoU: 0.2037 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6165 - loss: 0.536
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.2038 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6165 - loss: 0.5363 - val_MeanIoU: 0.2040 - val_PixelAccuracy: 0.3475 - val_dice_coeff: 0.0017 - val_loss: 158.6864
Epoch 3/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - MeanIoU: 0.2136 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6170 - loss: 0.510
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2134 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6172 - loss: 0.5109 - val_MeanIoU: 0.2069 - val_PixelAccuracy: 0.3203 - val_dice_coeff: 0.0035 - val_loss: 78.7073
Epoch 4/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - MeanIoU: 0.1999 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6504 - loss: 0.478
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2000 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6503 - loss: 0.4788 - val_MeanIoU: 0.2123 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.0155 - val_loss: 25.3732
Epoch 5/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - MeanIoU: 0.1975 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6645 - loss: 0.465
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.1975 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6644 - loss: 0.4655 - val_MeanIoU: 0.2029 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.0423 - val_loss: 19.0655
Epoch 6/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - MeanIoU: 0.2029 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6370 - loss: 0.470
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2029 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6371 - loss: 0.4714 - val_MeanIoU: 0.2169 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.0794 - val_loss: 2.7798
Epoch 7/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - MeanIoU: 0.2124 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6254 - loss: 0.474
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2122 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6258 - loss: 0.4750 - val_MeanIoU: 0.1785 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.0153 - val_loss: 14.4787
Epoch 8/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - MeanIoU: 0.2053 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6446 - loss: 0.458
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2052 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6445 - loss: 0.4592 - val_MeanIoU: 0.1998 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.0470 - val_loss: 4.3547
Epoch 9/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - MeanIoU: 0.2048 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6372 - loss: 0.466
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2050 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6370 - loss: 0.4672 - val_MeanIoU: 0.1840 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.1318 - val_loss: 2.5210
Epoch 10/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - MeanIoU: 0.2081 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6429 - loss: 0.448
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2080 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6432 - loss: 0.4483 - val_MeanIoU: 0.2063 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.2170 - val_loss: 2.6700
Epoch 11/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - MeanIoU: 0.2113 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6418 - loss: 0.465
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2112 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6419 - loss: 0.4658 - val_MeanIoU: 0.2040 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.5801 - val_loss: 0.6076
Epoch 12/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - MeanIoU: 0.2036 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6475 - loss: 0.450
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2035 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6479 - loss: 0.4503 - val_MeanIoU: 0.2094 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.2787 - val_loss: 2.6653
Epoch 13/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - MeanIoU: 0.2068 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6562 - loss: 0.449
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2068 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6561 - loss: 0.4489 - val_MeanIoU: 0.2000 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.4756 - val_loss: 1.1738
Epoch 14/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - MeanIoU: 0.2105 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6632 - loss: 0.421
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.2105 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6629 - loss: 0.4217 - val_MeanIoU: 0.1981 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.2818 - val_loss: 1.7027
Epoch 15/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - MeanIoU: 0.2014 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6603 - loss: 0.448
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 67s 2s/step - MeanIoU: 0.2015 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6602 - loss: 0.4490 - val_MeanIoU: 0.2160 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.4432 - val_loss: 1.3020
Epoch 16/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - MeanIoU: 0.2137 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6348 - loss: 0.457
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 67s 2s/step - MeanIoU: 0.2133 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6357 - loss: 0.4576 - val_MeanIoU: 0.2224 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.5913 - val_loss: 0.7013
Epoch 17/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - MeanIoU: 0.2055 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6704 - loss: 0.423
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.2057 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6702 - loss: 0.4240 - val_MeanIoU: 0.2235 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.3283 - val_loss: 1.2856
Epoch 18/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - MeanIoU: 0.2026 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6732 - loss: 0.402
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 68s 2s/step - MeanIoU: 0.2026 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6733 - loss: 0.4029 - val_MeanIoU: 0.1886 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.4122 - val_loss: 1.3875
Epoch 19/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - MeanIoU: 0.2131 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6556 - loss: 0.426
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2129 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6562 - loss: 0.4263 - val_MeanIoU: 0.2155 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.5022 - val_loss: 0.8983
Epoch 20/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - MeanIoU: 0.2023 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6678 - loss: 0.416
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2025 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6679 - loss: 0.4163 - val_MeanIoU: 0.2131 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6230 - val_loss: 0.5877
Epoch 21/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - MeanIoU: 0.2027 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6764 - loss: 0.405
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2028 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6762 - loss: 0.4058 - val_MeanIoU: 0.2200 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.4770 - val_loss: 0.8206
Epoch 22/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - MeanIoU: 0.2135 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6626 - loss: 0.416
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2131 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6630 - loss: 0.4163 - val_MeanIoU: 0.1916 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7351 - val_loss: 0.5219
Epoch 23/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - MeanIoU: 0.2094 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6714 - loss: 0.411
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2093 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6715 - loss: 0.4117 - val_MeanIoU: 0.1996 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7330 - val_loss: 0.6408
Epoch 24/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - MeanIoU: 0.2061 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6814 - loss: 0.399
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2062 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6810 - loss: 0.4002 - val_MeanIoU: 0.1916 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6788 - val_loss: 0.4418
Epoch 25/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - MeanIoU: 0.2006 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6744 - loss: 0.417
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2008 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6743 - loss: 0.4173 - val_MeanIoU: 0.2134 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.5442 - val_loss: 0.5786
Epoch 26/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - MeanIoU: 0.2001 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6675 - loss: 0.425
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2001 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6678 - loss: 0.4251 - val_MeanIoU: 0.2055 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6549 - val_loss: 0.4363
Epoch 27/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 78ms/step - MeanIoU: 0.2041 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6863 - loss: 0.386
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 64s 2s/step - MeanIoU: 0.2041 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6861 - loss: 0.3869 - val_MeanIoU: 0.2087 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.5097 - val_loss: 0.7656
Epoch 28/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - MeanIoU: 0.1990 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6716 - loss: 0.393
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.1991 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6715 - loss: 0.3941 - val_MeanIoU: 0.2002 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6198 - val_loss: 0.4951
Epoch 29/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - MeanIoU: 0.2024 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6895 - loss: 0.379
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2024 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6892 - loss: 0.3796 - val_MeanIoU: 0.2229 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6704 - val_loss: 0.4801
Epoch 30/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - MeanIoU: 0.2027 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6837 - loss: 0.393
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2027 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6836 - loss: 0.3935 - val_MeanIoU: 0.1855 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6057 - val_loss: 0.5543
Epoch 31/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - MeanIoU: 0.1981 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6964 - loss: 0.386
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.1981 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6963 - loss: 0.3861 - val_MeanIoU: 0.1994 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7256 - val_loss: 0.4344
Epoch 32/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - MeanIoU: 0.2029 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6942 - loss: 0.389
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2028 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6941 - loss: 0.3901 - val_MeanIoU: 0.2168 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6799 - val_loss: 0.4103
Epoch 33/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 78ms/step - MeanIoU: 0.2034 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6679 - loss: 0.399
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2035 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6681 - loss: 0.3993 - val_MeanIoU: 0.2016 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7077 - val_loss: 0.4347
Epoch 34/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - MeanIoU: 0.2018 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6803 - loss: 0.400
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2018 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6804 - loss: 0.4009 - val_MeanIoU: 0.2158 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6156 - val_loss: 0.4675
Epoch 35/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - MeanIoU: 0.2028 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6906 - loss: 0.373
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2030 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6902 - loss: 0.3739 - val_MeanIoU: 0.2052 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7184 - val_loss: 0.3623
Epoch 36/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - MeanIoU: 0.2006 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7026 - loss: 0.363
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 66s 2s/step - MeanIoU: 0.2007 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7023 - loss: 0.3637 - val_MeanIoU: 0.2227 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6607 - val_loss: 0.3603
Epoch 37/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - MeanIoU: 0.1988 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6915 - loss: 0.381
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.1989 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6914 - loss: 0.3816 - val_MeanIoU: 0.2118 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6805 - val_loss: 0.3887
Epoch 38/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - MeanIoU: 0.1990 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7049 - loss: 0.361
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.1992 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7044 - loss: 0.3625 - val_MeanIoU: 0.2003 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7084 - val_loss: 0.4957
Epoch 39/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - MeanIoU: 0.2047 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6970 - loss: 0.366
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2047 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6970 - loss: 0.3665 - val_MeanIoU: 0.1988 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6954 - val_loss: 0.3742
Epoch 40/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - MeanIoU: 0.2096 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6706 - loss: 0.401
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2094 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6710 - loss: 0.4015 - val_MeanIoU: 0.2065 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7359 - val_loss: 0.5942
Epoch 41/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - MeanIoU: 0.2073 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6944 - loss: 0.357
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2072 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6944 - loss: 0.3576 - val_MeanIoU: 0.2351 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6398 - val_loss: 0.4517
Epoch 42/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step - MeanIoU: 0.2052 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6878 - loss: 0.369
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2053 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6874 - loss: 0.3703 - val_MeanIoU: 0.2146 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6190 - val_loss: 0.4801
Epoch 43/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - MeanIoU: 0.2019 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6970 - loss: 0.374
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2020 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6968 - loss: 0.3748 - val_MeanIoU: 0.2058 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7089 - val_loss: 0.3656
Epoch 44/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - MeanIoU: 0.2056 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6960 - loss: 0.357
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2055 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6963 - loss: 0.3576 - val_MeanIoU: 0.2261 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6576 - val_loss: 0.4152
Epoch 45/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - MeanIoU: 0.2048 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6932 - loss: 0.357
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.2049 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6930 - loss: 0.3579 - val_MeanIoU: 0.2119 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.6908 - val_loss: 0.4689
Epoch 46/100
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - MeanIoU: 0.1956 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7229 - loss: 0.348
No description has been provided for this image
32/32 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - MeanIoU: 0.1959 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.7220 - loss: 0.3494 - val_MeanIoU: 0.2000 - val_PixelAccuracy: 0.0000e+00 - val_dice_coeff: 0.7275 - val_loss: 0.3793

Model Learning Curve¶

In [25]:
# Model History
history = unet_model_history.history
history
Out[25]:
{'MeanIoU': [0.1992032825946808,
  0.20723021030426025,
  0.2089681178331375,
  0.20103561878204346,
  0.19850480556488037,
  0.20233049988746643,
  0.20658963918685913,
  0.2036864310503006,
  0.20984354615211487,
  0.2046331912279129,
  0.20716583728790283,
  0.20179182291030884,
  0.20636393129825592,
  0.2093118131160736,
  0.20593519508838654,
  0.2030024528503418,
  0.21003907918930054,
  0.20242716372013092,
  0.20756644010543823,
  0.20600050687789917,
  0.20601294934749603,
  0.20151656866073608,
  0.20484960079193115,
  0.20811578631401062,
  0.20776106417179108,
  0.19975927472114563,
  0.20528320968151093,
  0.2022167146205902,
  0.20413976907730103,
  0.20312047004699707,
  0.20042486488819122,
  0.2013271152973175,
  0.20636247098445892,
  0.20346450805664062,
  0.20794083178043365,
  0.20383527874946594,
  0.20293761789798737,
  0.203046053647995,
  0.2042224109172821,
  0.20371179282665253,
  0.20469354093074799,
  0.2084343433380127,
  0.2064168006181717,
  0.20335863530635834,
  0.20726317167282104,
  0.20414483547210693],
 'PixelAccuracy': [1.9073485191256623e-07,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 'dice_coeff': [0.5992616415023804,
  0.6143656373023987,
  0.6241587400436401,
  0.647817075252533,
  0.6599338054656982,
  0.6411471366882324,
  0.6401103734970093,
  0.6434901356697083,
  0.6323947906494141,
  0.6539073586463928,
  0.6431247591972351,
  0.6608415842056274,
  0.6543480157852173,
  0.6542718410491943,
  0.6568323969841003,
  0.6639542579650879,
  0.6646468043327332,
  0.6780006289482117,
  0.6735907196998596,
  0.6692067384719849,
  0.671150267124176,
  0.6757012009620667,
  0.6747865080833435,
  0.6676671504974365,
  0.6718212366104126,
  0.6775573492050171,
  0.6802664995193481,
  0.6669780611991882,
  0.6803557872772217,
  0.6801664233207703,
  0.6923991441726685,
  0.6913303136825562,
  0.6717274188995361,
  0.6860343813896179,
  0.6770738959312439,
  0.6927362680435181,
  0.6867203116416931,
  0.6892097592353821,
  0.696695864200592,
  0.6829046607017517,
  0.6935202479362488,
  0.6724157929420471,
  0.6902781128883362,
  0.7051346898078918,
  0.6867726445198059,
  0.6933310031890869],
 'loss': [0.6324572563171387,
  0.5349217653274536,
  0.5111037492752075,
  0.4871182441711426,
  0.4752787947654724,
  0.4861675202846527,
  0.4822116494178772,
  0.4696265757083893,
  0.47592490911483765,
  0.4478815197944641,
  0.4760293960571289,
  0.4480341970920563,
  0.4450124204158783,
  0.44162890315055847,
  0.45347458124160767,
  0.44626083970069885,
  0.42650797963142395,
  0.40765997767448425,
  0.40924662351608276,
  0.4210212826728821,
  0.40673133730888367,
  0.41711413860321045,
  0.4102461338043213,
  0.4149117171764374,
  0.421040415763855,
  0.4233340620994568,
  0.39630216360092163,
  0.424282968044281,
  0.39938491582870483,
  0.4032699167728424,
  0.38839882612228394,
  0.3984592854976654,
  0.401938796043396,
  0.398421049118042,
  0.3912218511104584,
  0.38229379057884216,
  0.38274428248405457,
  0.38439831137657166,
  0.37204596400260925,
  0.4020407497882843,
  0.37592634558677673,
  0.40169450640678406,
  0.3803367614746094,
  0.35325393080711365,
  0.3741324841976166,
  0.3742380738258362],
 'val_MeanIoU': [0.19999520480632782,
  0.20395812392234802,
  0.20691031217575073,
  0.21234601736068726,
  0.20294633507728577,
  0.21694649755954742,
  0.17849156260490417,
  0.19981227815151215,
  0.18399962782859802,
  0.20628155767917633,
  0.20395368337631226,
  0.20939287543296814,
  0.20003190636634827,
  0.198095440864563,
  0.2160402089357376,
  0.22242309153079987,
  0.22348520159721375,
  0.1886315941810608,
  0.21554818749427795,
  0.21313956379890442,
  0.22003304958343506,
  0.19160348176956177,
  0.1996234953403473,
  0.19158752262592316,
  0.21338379383087158,
  0.20552490651607513,
  0.20871075987815857,
  0.20019182562828064,
  0.22285757958889008,
  0.18552769720554352,
  0.19937221705913544,
  0.216833233833313,
  0.2015785425901413,
  0.21578744053840637,
  0.20519505441188812,
  0.22265677154064178,
  0.2117667943239212,
  0.20026306807994843,
  0.19879455864429474,
  0.2064943164587021,
  0.23507864773273468,
  0.214645653963089,
  0.20577462017536163,
  0.22608904540538788,
  0.2119479775428772,
  0.1999789923429489],
 'val_PixelAccuracy': [0.3757653534412384,
  0.3474501073360443,
  0.3203323781490326,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 'val_dice_coeff': [0.002381600672379136,
  0.0017498626839369535,
  0.0034958492033183575,
  0.015481376089155674,
  0.042307049036026,
  0.07937664538621902,
  0.015303398482501507,
  0.04697667807340622,
  0.13179059326648712,
  0.21703331172466278,
  0.5800686478614807,
  0.2786713242530823,
  0.47563308477401733,
  0.28178802132606506,
  0.4432109296321869,
  0.5913054347038269,
  0.3282725512981415,
  0.41223838925361633,
  0.5021713376045227,
  0.6229732632637024,
  0.4769725501537323,
  0.7351338267326355,
  0.7330007553100586,
  0.678837239742279,
  0.5442424416542053,
  0.6548649668693542,
  0.5096802115440369,
  0.6198228597640991,
  0.6704198122024536,
  0.6056585907936096,
  0.7256316542625427,
  0.6798519492149353,
  0.70768803358078,
  0.615635871887207,
  0.7184128761291504,
  0.6607317924499512,
  0.6804813742637634,
  0.7083863019943237,
  0.6953719258308411,
  0.7359205484390259,
  0.6398470997810364,
  0.6189655065536499,
  0.7089042067527771,
  0.6576145887374878,
  0.690753161907196,
  0.727544903755188],
 'val_loss': [2211.156005859375,
  158.6863555908203,
  78.70732116699219,
  25.373165130615234,
  19.065534591674805,
  2.7797865867614746,
  14.478734970092773,
  4.354736328125,
  2.5210089683532715,
  2.6700148582458496,
  0.6075860857963562,
  2.665254592895508,
  1.173814058303833,
  1.7027209997177124,
  1.3020278215408325,
  0.7013348937034607,
  1.285564661026001,
  1.3875261545181274,
  0.8983238935470581,
  0.5877137184143066,
  0.8205947875976562,
  0.521854043006897,
  0.6408008933067322,
  0.4418080747127533,
  0.5785664319992065,
  0.4362673759460449,
  0.7655749917030334,
  0.4951431155204773,
  0.48009970784187317,
  0.5542751550674438,
  0.434432715177536,
  0.4103431701660156,
  0.4347221553325653,
  0.4674955904483795,
  0.36232373118400574,
  0.3602845370769501,
  0.3887442946434021,
  0.49567872285842896,
  0.37423327565193176,
  0.594249427318573,
  0.45169612765312195,
  0.4800850450992584,
  0.36560124158859253,
  0.41520828008651733,
  0.46888765692710876,
  0.37928253412246704]}

Model Predictions¶

In [26]:
show_images_and_masks(data=train_ds, model=unet_model)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 46ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 42ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 41ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step
No description has been provided for this image
In [27]:
show_images_and_masks(data=valid_ds, model=unet_model)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 52ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 47ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 50ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 49ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 85ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 42ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step
No description has been provided for this image
In [28]:
show_images_and_masks(data=test_ds, model=unet_model)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 48ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 47ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 41ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 48ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step
No description has been provided for this image
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step
No description has been provided for this image

Evaluation¶

In [29]:
# U-Net test accuracy 
results = unet_model.evaluate(test_ds, verbose=1)
# print(f'Test Loss: {test_loss:.4f}')
# print(f'Test Accuracy: {test_accuracy:.4f}')
# print(f'Test IoU: {test_iou:.4f}')
7/7 ━━━━━━━━━━━━━━━━━━━━ 3s 491ms/step - MeanIoU: 0.2065 - PixelAccuracy: 0.0000e+00 - dice_coeff: 0.6806 - loss: 0.3672

Show train and validation loss¶

In [30]:
epochs = range(1, len(unet_model_history.history["loss"]) + 1)
loss = unet_model_history.history["loss"]
val_loss = unet_model_history.history["val_loss"]
plt.figure()
plt.plot(epochs, loss, "r", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
Out[30]:
<matplotlib.legend.Legend at 0x161b67cdc10>
No description has been provided for this image
In [32]:
# Extract the history object from the training
history = unet_model_history.history

# Define the list of metrics you want to plot
evaluation_metrics = ['PixelAccuracy', 'MeanIoU', 'dice_coeff']

# Create subplots for each metric
plt.figure(figsize=(18, 6))

for i, metric in enumerate(evaluation_metrics):
    plt.subplot(1, len(evaluation_metrics), i + 1)
    plt.plot(history[metric], label='Train')
    plt.plot(history[f'val_{metric}'], label='Validation')
    plt.title(f'{metric} over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel(metric)
    plt.legend()

plt.tight_layout()
plt.show()
No description has been provided for this image

SegNet (A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation)¶

No description has been provided for this image
In [ ]:
# #import tensorflow as tf
# from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, BatchNormalization, ReLU
# from tensorflow.keras.models import Model
# #from tensorflow.keras.optimizers import Adam
# #from tensorflow.keras.losses import BinaryCrossentropy
# #from tensorflow.keras.metrics import BinaryAccuracy, MeanIoU
# from tensorflow.keras import metrics 


# class SegNet:
#     def __init__(self, input_shape, num_classes=1, base_lr=0.001):
#         self.input_shape = input_shape
#         self.num_classes = num_classes
#         self.base_lr = base_lr
#         self.model = self.build_model()
#         self.compile()

#     def conv_block(self, x, filters, kernel_size=3, strides=1, padding='same'):
#         x = Conv2D(filters, kernel_size=kernel_size, strides=strides, padding=padding)(x)
#         x = BatchNormalization()(x)
#         x = ReLU()(x)
#         return x

#     def encoder_block(self, x, filters):
#         x = self.conv_block(x, filters)
#         x = self.conv_block(x, filters)
#         p = MaxPooling2D(pool_size=2, strides=2, padding='same')(x)
#         return x, p

#     def decoder_block(self, x, skip, filters):
#         x = Conv2DTranspose(filters, kernel_size=3, strides=2, padding='same')(x)
#         x = tf.concat([x, skip], axis=-1)
#         x = self.conv_block(x, filters)
#         x = self.conv_block(x, filters)
#         return x

#     def build_model(self):
#         inputs = Input(shape=self.input_shape)

#         # Encoder
#         e1, p1 = self.encoder_block(inputs, 64)
#         e2, p2 = self.encoder_block(p1, 128)
#         e3, p3 = self.encoder_block(p2, 256)
#         e4, p4 = self.encoder_block(p3, 512)
#         e5, p5 = self.encoder_block(p4, 512)

#         # Decoder
#         d5 = self.decoder_block(p5, e5, 512)
#         d4 = self.decoder_block(d5, e4, 512)
#         d3 = self.decoder_block(d4, e3, 256)
#         d2 = self.decoder_block(d3, e2, 128)
#         d1 = self.decoder_block(d2, e1, 64)

#         # Final layer for binary segmentation
#         outputs = Conv2D(self.num_classes, kernel_size=1, activation='sigmoid')(d1)

#         model = Model(inputs=inputs, outputs=outputs)
#         return model

#     def compile(self):
#         # Define metrics
#         pixel_acc = metrics.Accuracy(name="PixelAccuracy")
#         mean_iou = metrics.MeanIoU(num_classes=2, name="MeanIoU")

#         # Learning Rate Schedule
#         lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
#             initial_learning_rate=self.base_lr,
#             decay_steps=500,
#             decay_rate=0.96,
#             staircase=True
#         )
        
#         # Optimizer
#         optimizer = optimizers.Adam(learning_rate=lr_schedule)

#         # Compile the model
#         self.model.compile(
#             loss='binary_crossentropy',
#             optimizer=optimizer,
#             metrics=[pixel_acc, mean_iou, dice_coeff]
#         )

#     def summary(self):
#         self.model.summary()
In [ ]:
# input_shape = (IMAGE_WIDTH, IMAGE_HEIGHT, N_IMAGE_CHANNELS)  
# num_classes = 1  # Single channel output for binary segmentation

# # Create the SegNet model
# segnet = SegNet(input_shape, num_classes)

# # Compile the model
# segnet.compile()

# # Print model summary
# segnet.summary()
In [ ]:
# # Model Training
# segnet_model_history = segnet.model.fit(
#     train_ds,
#     validation_data = valid_ds,
#     epochs = EPOCHS,
#     callbacks = CALLBACKS,
#     batch_size = BATCH_SIZE,
# )
In [ ]: